"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
import pytest
import torch

import flashinfer
from flashinfer.utils import get_compute_capability, has_flashinfer_jit_cache
from tests.test_helpers.jit_utils import (
    gen_persistent_batch_attention_modules,
    gen_prefill_attention_modules,
)


@pytest.fixture(
    autouse=not has_flashinfer_jit_cache(),
    scope="module",
)
def warmup_jit():
    """Build the JIT kernel specs exercised by these tests ahead of time."""
    flashinfer.jit.build_jit_specs(
        gen_persistent_batch_attention_modules(
            [torch.float16, torch.bfloat16],  # q_dtypes
            [torch.float16, torch.bfloat16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [False, True],  # use_logits_soft_cap
        )
        + gen_prefill_attention_modules(
            [torch.float16, torch.bfloat16],  # q_dtypes
            [torch.float16, torch.bfloat16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False, True],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )


# -------------------------  Configuration generation function  ----------------------------- #
def _build_seq_len_configs():
    """
    Reproduce the sequence length configurations from the original benchmark (including random cases).
    Returns: List[List[Tuple[int,int]]]  -> Each element is a list of (kv_len, qo_len) pairs.
    """
    np.random.seed(42)
    torch.manual_seed(42)

    seq_len_configs = [
        [(146, 146)],
        [(67, 67)],
        [(8190, 7939)],
        [(2048, 1)] * 77,  # decode-only
        [(4099, 129)] * 2,  # prefill-only
        [(600, 1)] * 132 * 2 + [(5000, 3)] * 128,
        [(1024, 1)] * 100 + [(8192, 17)] * 8,  # speculative decode
        [(766, 2)] * 99 + [(1024, 512)] * 1,  # chunked prefill
        [(2, 235)] + [(1, 13353)],  # real workload
    ]

    # Construct a randomized batch: every `stride`-th request keeps its full KV
    # length and gets qo_len = stride + 1; the rest are short decode requests
    # whose KV length is down-sampled by `sparsity`.
    bsz, stride, sparsity = 256, 16, 0.05
    full_kv_len = np.random.randint(1000, 11000, size=bsz)
    seq_len = []
    for i in range(bsz):
        if i % stride == 0:
            kv_len, qo_len = full_kv_len[i], stride + 1
        else:
            kv_len, qo_len = int(full_kv_len[i] * sparsity), 1
        seq_len.append((kv_len, qo_len))
    seq_len_configs.append(seq_len)

    return seq_len_configs


def _run_attention(
    kv_lens,
    qo_lens,
    page_block_size=1,
    num_kv_heads=1,
    num_qo_heads=1,
    head_dim=128,
    v_scale=None,
    layout="NHD",
    test_dtype=torch.bfloat16,
    logits_soft_cap=0.0,
    device="cuda",
    causal=True,
    is_chunked_q=False,
):
    """
    Run both implementations and return (output_old, lse_old, output_new, lse_new)
    """
    dev = torch.device(device)
    seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=dev)
    q_lens = torch.tensor(qo_lens, dtype=torch.int32, device=dev)

    # Number of KV pages each request occupies (ceil division by the page size).
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()

    # CSR-style exclusive prefix sums: request i owns query rows
    # [q_indptr[i], q_indptr[i + 1]) and KV pages [kv_indptr[i], kv_indptr[i + 1]).
    q_indptr = torch.cat(
        [torch.tensor([0], device=dev), torch.cumsum(q_lens, 0)], dim=0
    ).int()
    kv_indptr = torch.cat(
        [torch.tensor([0], device=dev), torch.cumsum(seq_lens_blocks, 0)], dim=0
    ).int()

    num_blocks = kv_indptr[-1].item()

    if is_chunked_q:
        # Build a non-contiguous q: take the first head_dim columns of a buffer
        # twice as wide, so q is a strided view sharing storage with q_base.
        q_base = torch.rand(
            q_indptr[-1].item(),
            num_qo_heads,
            head_dim * 2,
            dtype=test_dtype,
            device=dev,
        )
        q = torch.chunk(q_base, 2, dim=-1)[0]
    else:
        q = torch.rand(
            q_indptr[-1].item(), num_qo_heads, head_dim, dtype=test_dtype, device=dev
        )
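
    # Paged KV cache: dim 1 indexes K (0) / V (1); "NHD" lays each page out as
    # [page_size, num_heads, head_dim] while "HND" uses [num_heads, page_size, head_dim].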
    if layout == "NHD":
        kv_data = torch.randn(
            num_blocks,
            2,
            page_block_size,
            num_kv_heads,
            head_dim,
            dtype=test_dtype,
            device=dev,
        )
    elif layout == "HND":
        kv_data = torch.randn(
            num_blocks,
            2,
            num_kv_heads,
            page_block_size,
            head_dim,
            dtype=test_dtype,
            device=dev,
        )

    # --------- old scheduler --------- #
    wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),  # 256 MB workspace
        kv_layout=layout,
        backend="fa2",
    )
    # Length of the final, possibly partial, page of each request, in [1, page_block_size].
    last_page_len = (seq_lens - 1) % page_block_size + 1
    wrapper_old.plan(
        q_indptr,
        kv_indptr,
        torch.arange(num_blocks, device=dev).int(),
        last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=test_dtype,
        kv_data_type=test_dtype,
        logits_soft_cap=logits_soft_cap,
    )
    out_old, lse_old = wrapper_old.run(q, kv_data, return_lse=True, v_scale=v_scale)

    # --------- new / mixed scheduler --------- #
    # Note: BatchAttention.plan takes per-request token counts (seq_lens) rather
    # than last_page_len, and head_dim appears twice (QK and VO head dims).
    wrapper = flashinfer.BatchAttention(kv_layout=layout)
    wrapper.plan(
        q_indptr,
        kv_indptr,
        torch.arange(num_blocks, device=dev).int(),
        seq_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=test_dtype,
        kv_data_type=test_dtype,
        logits_soft_cap=logits_soft_cap,
    )
    out_new, lse_new = wrapper.run(
        q, kv_data, v_scale=v_scale, logits_soft_cap=logits_soft_cap
    )

    torch.cuda.synchronize()
    torch.testing.assert_close(out_old, out_new, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(lse_old, lse_new, rtol=1e-2, atol=1e-2)


# -------------------------  PyTest test cases  ----------------------------- #
@pytest.mark.xfail(
    get_compute_capability(torch.device(device="cuda"))[0] == 12,
    reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
)
def test_batch_attention_with_noncontiguous_q():
    # Use the first sequence-length configuration (a single-request batch).
    seq_len_pairs = _build_seq_len_configs()[0]
    kv_lens = [p[0] for p in seq_len_pairs]
    qo_lens = [p[1] for p in seq_len_pairs]

    # Fixed single-case parameters
    page_block_size = 1
    num_kv_heads = 1
    gqa_group_size = 1
    num_qo_heads = num_kv_heads * gqa_group_size
    head_dim = 64
    test_dtype = torch.bfloat16
    layout = "NHD"
    logits_soft_cap = 0.0
    v_scale = None
    causal = True

    _run_attention(
        kv_lens=kv_lens,
        qo_lens=qo_lens,
        page_block_size=page_block_size,
        num_kv_heads=num_kv_heads,
        num_qo_heads=num_qo_heads,
        head_dim=head_dim,
        v_scale=v_scale,
        causal=causal,
        layout=layout,
        test_dtype=test_dtype,
        logits_soft_cap=logits_soft_cap,
        device="cuda",
        is_chunked_q=True,
    )


@pytest.mark.xfail(
    get_compute_capability(torch.device(device="cuda"))[0] == 12,
    reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
)
@pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs())
@pytest.mark.parametrize("page_block_size", [1, 8, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("v_scale", [2.0, None])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("layout", ["HND", "NHD"])
@pytest.mark.parametrize("test_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0])
def test_batch_attention_correctness(
    seq_len_pairs,
    page_block_size,
    num_kv_heads,
    gqa_group_size,
    head_dim,
    v_scale,
    causal,
    layout,
    test_dtype,
    logits_soft_cap,
):
    num_qo_heads = num_kv_heads * gqa_group_size
    kv_lens = [p[0] for p in seq_len_pairs]
    qo_lens = [p[1] for p in seq_len_pairs]

    _run_attention(
        kv_lens=kv_lens,
        qo_lens=qo_lens,
        page_block_size=page_block_size,
        num_kv_heads=num_kv_heads,
        num_qo_heads=num_qo_heads,
        head_dim=head_dim,
        v_scale=v_scale,
        causal=causal,
        layout=layout,
        test_dtype=test_dtype,
        logits_soft_cap=logits_soft_cap,
        device="cuda",
    )


if __name__ == "__main__":
    test_batch_attention_correctness(
        seq_len_pairs=[(1000, 1000)],
        page_block_size=1,
        num_kv_heads=4,
        gqa_group_size=7,
        head_dim=128,
        v_scale=None,
        causal=True,
        layout="NHD",
        test_dtype=torch.bfloat16,
        logits_soft_cap=0.0,
    )
